/*
 * Copyright 2018 The Cartographer Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cartographer/mapping/internal/2d/scan_matching/occupied_space_cost_function_2d.h"
#include "cartographer/mapping/probability_values.h"
#include "ceres/cubic_interpolation.h"
#include "cartographer/rcm/match/helper.h"

namespace cartographer
{
    namespace mapping
    {
        namespace scan_matching
        {
            namespace
            {
                // Ceres auto-diff functor that computes a cost for matching the 'point_cloud'
                // to the 'grid' with a 'pose'. The cost increases with poorer correspondence
                // of the grid and the point observation (e.g. points falling into less
                // occupied space).
                //
                // One residual is produced per point, so 'residual' must have room for
                // point_cloud.size() entries. 'point_cloud' and 'grid' are stored by
                // reference and must outlive this object.
                //
                // Side effect: every evaluation publishes the mean scaled residual to
                // MatchSupervisor::GetInstance().residualProb_.
                class OccupiedSpaceCostFunction2D
                {
                public:
                    OccupiedSpaceCostFunction2D(const double scaling_factor, const sensor::PointCloud &point_cloud, const Grid2D &grid)
                        : scaling_factor_(scaling_factor)
                        , point_cloud_(point_cloud)
                        , grid_(grid)
                    {
                    }

                    // Evaluates the residuals for 'pose' = (x, y, rotation angle).
                    //
                    // T is either double or an auto-diff scalar type. Always returns true.
                    template<typename T> bool operator()(const T *const pose, T *residual) const
                    {
                        // Assemble the homogeneous 2D transform described by 'pose'.
                        Eigen::Matrix<T, 2, 1> translation(pose[0], pose[1]);
                        Eigen::Rotation2D<T> rotation(pose[2]);
                        Eigen::Matrix<T, 2, 2> rotation_matrix = rotation.toRotationMatrix();
                        Eigen::Matrix<T, 3, 3> transform;
                        transform << rotation_matrix, translation, T(0.), T(0.), T(1.);

                        const GridArrayAdapter adapter(grid_);
                        ceres::BiCubicInterpolator<GridArrayAdapter> interpolator(adapter);
                        const MapLimits &limits = grid_.limits();

                        // Sum of the value parts of the scaled residuals; derivative parts
                        // of auto-diff scalars are intentionally ignored here.
                        double total_cost = 0.0;
                        const size_t num_points = point_cloud_.size();
                        for (size_t i = 0; i < num_points; ++i)
                        {
                            // Note that this is a 2D point. The third component is a scaling factor.
                            const Eigen::Matrix<T, 3, 1> point((T(point_cloud_[i].position.x())), (T(point_cloud_[i].position.y())), T(1.));
                            const Eigen::Matrix<T, 3, 1> world = transform * point;
                            // Convert the world position to continuous indices of the padded
                            // array exposed by GridArrayAdapter, then interpolate the cost.
                            interpolator.Evaluate((limits.max().x() - world[0]) / limits.resolution() - 0.5 + static_cast<double>(kPadding),
                                                  (limits.max().y() - world[1]) / limits.resolution() - 0.5 + static_cast<double>(kPadding),
                                                  &residual[i]);
                            residual[i] = scaling_factor_ * residual[i];
                            total_cost += ScalarPart(residual[i]);
                        }

                        // Publish the mean scaled residual. An empty point cloud reports 0
                        // instead of dividing by zero (which would publish NaN).
                        // NOTE(review): this writes to a process-wide singleton from a const
                        // method without any synchronization visible here - confirm that
                        // concurrent evaluations cannot race on residualProb_.
                        MatchSupervisor::GetInstance().residualProb_ =
                            (num_points > 0) ? total_cost / static_cast<double>(num_points) : 0.0;
                        return true;
                    }

                private:
                    // Width, in cells, of the virtual border added on every side of the grid
                    // by GridArrayAdapter.
                    static constexpr int kPadding = INT_MAX / 4;

                    // Returns 'value' unchanged; used when the functor runs with T = double.
                    static double ScalarPart(const double value)
                    {
                        return value;
                    }

                    // Returns the member 'a' of an auto-diff scalar as a double. This overload
                    // only participates for types that have such a member (for ceres::Jet,
                    // 'a' is the value part and the derivatives are stored separately).
                    template<typename JetT> static auto ScalarPart(const JetT &jet) -> decltype(static_cast<double>(jet.a))
                    {
                        return static_cast<double>(jet.a);
                    }

                    // Presents 'grid' to ceres::BiCubicInterpolator as a 2D array that is
                    // extended by kPadding cells on each side.
                    class GridArrayAdapter
                    {
                    public:
                        enum
                        {
                            DATA_DIMENSION = 1
                        };

                        explicit GridArrayAdapter(const Grid2D &grid)
                            : grid_(grid)
                        {
                        }

                        // Writes the correspondence cost at padded index ('row', 'column')
                        // to '*value'. Indices inside the padding border yield
                        // kMaxCorrespondenceCost; all others read the underlying grid cell.
                        void GetValue(const int row, const int column, double *const value) const
                        {
                            if (row < kPadding || column < kPadding || row >= NumRows() - kPadding || column >= NumCols() - kPadding)
                            {
                                *value = kMaxCorrespondenceCost;
                            }
                            else
                            {
                                *value =
                                    static_cast<double>(grid_.GetCorrespondenceCost(Eigen::Array2i(column - kPadding, row - kPadding)));
                            }
                        }

                        // Number of rows of the padded array.
                        int NumRows() const
                        {
                            return grid_.limits().cell_limits().num_y_cells + 2 * kPadding;
                        }

                        // Number of columns of the padded array.
                        int NumCols() const
                        {
                            return grid_.limits().cell_limits().num_x_cells + 2 * kPadding;
                        }

                    private:
                        const Grid2D &grid_;
                    };

                    OccupiedSpaceCostFunction2D(const OccupiedSpaceCostFunction2D &) = delete;
                    OccupiedSpaceCostFunction2D &operator=(const OccupiedSpaceCostFunction2D &) = delete;

                    const double scaling_factor_;
                    const sensor::PointCloud &point_cloud_;
                    const Grid2D &grid_;
                };

            } // namespace

            // Creates an auto-differentiated cost function that evaluates
            // OccupiedSpaceCostFunction2D over three pose variables, with one residual
            // per point of 'point_cloud'.
            //
            // The functor keeps references to 'point_cloud' and 'grid', so both must
            // outlive the returned object. The result is heap-allocated and returned as
            // a raw pointer; per the Ceres API the cost function takes ownership of the
            // functor by default.
            ceres::CostFunction *CreateOccupiedSpaceCostFunction2D(const double scaling_factor, const sensor::PointCloud &point_cloud,
                                                                   const Grid2D &grid)
            {
                using AutoDiffCost =
                    ceres::AutoDiffCostFunction<OccupiedSpaceCostFunction2D, ceres::DYNAMIC /* residuals */, 3 /* pose variables */>;

                OccupiedSpaceCostFunction2D *const functor = new OccupiedSpaceCostFunction2D(scaling_factor, point_cloud, grid);
                const auto residual_count = point_cloud.size();
                return new AutoDiffCost(functor, residual_count);
            }
        } // namespace scan_matching
    }     // namespace mapping
} // namespace cartographer
